import torch
import time
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary

import torchvision
from torchvision import datasets, transforms
import numpy as np
import os
import scipy.io as io
from config import opt
from models import *
from path import *
from attack import *
from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent, fast_gradient_method


def make_device(device_no):
    """Pin the process to one GPU index and return the device to compute on.

    Args:
        device_no: GPU index; its string form is written to the
            ``CUDA_VISIBLE_DEVICES`` environment variable.

    Returns:
        ``torch.device("cuda:0")`` when ``torch.cuda.is_available()`` is
        true, otherwise ``torch.device("cpu")``.
    """
    visible_gpu = str(device_no)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu

    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda:0")


def make_data_loader(data_name, root="./data", batch_size=64, shuffle=True):
    """Build the train and test ``DataLoader`` for a supported dataset.

    Training images get random augmentation (random crop with 4-pixel
    padding, plus a horizontal flip for the CIFAR datasets); test images are
    only converted to tensors. The CIFAR pipelines also apply a
    ``Normalize`` with mean 0 and std 1.

    Args:
        data_name: one of ``"cifar10"``, ``"cifar100"`` or ``"mnist"``.
        root: directory the dataset is stored in / downloaded to.
        batch_size: batch size used for both loaders.
        shuffle: passed to both loaders (the test loader is shuffled too).

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset. Previously
            the error was only printed and the function then failed with an
            unrelated unbound-variable error.
    """
    supported = ("cifar10", "cifar100", "mnist")
    if data_name not in supported:
        raise ValueError("Dataset \'{}\' is not used by us".format(data_name))

    if data_name == "mnist":
        dataset_cls = datasets.MNIST
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])
    else:
        # cifar10 and cifar100 share exactly the same preprocessing.
        dataset_cls = datasets.CIFAR10 if data_name == "cifar10" else datasets.CIFAR100
        norm_mean = 0
        norm_var = 1
        channel_mean = (norm_mean, norm_mean, norm_mean)
        channel_std = (norm_var, norm_var, norm_var)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    train_dataset = dataset_cls(root, train=True, download=True, transform=transform_train)
    test_dataset = dataset_cls(root, train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

    return train_loader, test_loader


def make_model(model_name, data_name, device):
    """Instantiate the network for a dataset and move it to ``device``.

    Selection rules:
      * ``cifar10`` (10 classes): ``resnet18``, ``preactresnet18``,
        ``wideresnet28_10`` or ``wideresnet34_10``; any other ``model_name``
        falls back to ``VGG('VGG16')``.
      * ``mnist``: always ``Lenet``; ``model_name`` is ignored.
      * ``cifar100`` (100 classes): ``resnet34``; any other ``model_name``
        falls back to ``PreActResNet34``.

    Args:
        model_name: architecture identifier (see rules above).
        data_name: ``"cifar10"``, ``"mnist"`` or ``"cifar100"``.
        device: target passed to ``.to(device)``.

    Returns:
        The constructed model on ``device``.

    Raises:
        ValueError: if ``data_name`` is not supported. Previously the error
            was printed and ``None`` was returned, and the message wrongly
            named the model instead of the dataset.
    """
    if data_name == "cifar10":
        nb_classes = 10
        cifar10_builders = {
            'resnet18': ResNet18,
            'preactresnet18': PreActResNet18,
            'wideresnet28_10': wideresnet28_10,
            'wideresnet34_10': wideresnet34_10,
        }
        builder = cifar10_builders.get(model_name)
        if builder is None:
            # Fallback kept for backward compatibility (e.g. model_name='vgg').
            return VGG('VGG16').to(device)
        return builder(nb_classes).to(device)

    if data_name == "mnist":
        return Lenet().to(device)

    if data_name == "cifar100":
        nb_classes = 100
        if model_name == 'resnet34':
            return ResNet34(nb_classes).to(device)
        # Fallback kept for backward compatibility.
        return PreActResNet34(nb_classes).to(device)

    raise ValueError("Dataset \"{}\" is not used by us".format(data_name))


def train(**kwargs):
    """Train one model under the configured training mode, evaluating every epoch.

    ``kwargs`` are handed to ``opt.parse``; the run configuration
    (``model_name``, ``data_name``, ``train_mode``, ``device_no``,
    ``batch_size``, ``max_epoch``, ``JR_lamda``, ``learning_rate``) is then
    read back from ``opt``.

    Outline:
      1. Build the device, data loaders and model.
      2. Load ``<data_name>_pca.pt`` from the working directory, or compute
         the eigenvectors of the training-pixel covariance and cache them
         there.
      3. For ``max_epoch`` epochs: run the selected train step on every
         training batch, then evaluate every test batch on clean, PGD and
         FGSM inputs.
      4. After each epoch: save the best-clean, best-PGD and per-epoch
         checkpoints plus the metric histories (``.npy``) under the
         directory returned by ``get_ckpt_save_path``.

    Supported ``train_mode`` values: ``std``, ``adv``, ``fat``, ``fgsm``,
    ``small_scale``, ``ex_fgsm``, ``grad_align``.

    Raises:
        ValueError: if ``data_name`` has no known input shape, or if
            ``train_mode`` is not one of the supported modes.
    """
    opt.parse(kwargs)
    model_name, data_name, train_mode = opt.model_name, opt.data_name, opt.train_mode
    device_no, batch_size, max_epoch = opt.device_no, opt.batch_size, opt.max_epoch
    JR_lamda, learning_rate = opt.JR_lamda, opt.learning_rate

    device = make_device(device_no)

    train_loader, test_loader = make_data_loader(data_name, "./data", batch_size)

    # Raw (un-augmented) training images; used only for the PCA statistics.
    train_image = torch.as_tensor(train_loader.dataset.data)
    if data_name != 'mnist':
        # Channels-last -> channels-first, matching input_shape below.
        train_image = train_image.permute(0, 3, 1, 2)
    train_image = train_image / 255  # pixels to [0, 1]
    model = make_model(model_name, data_name, device)
    start_time = time.time()

    # save path
    cpkt_save_path = get_ckpt_save_path(train_mode, model_name, data_name, JR_lamda)

    # Eigen-decomposition of the data covariance; rows of `evecs` are the
    # eigenvectors (principal directions).
    x_train_mean = torch.mean(train_image.float(), 0)

    pca_result = data_name + '_pca.pt'
    if os.path.exists(pca_result):
        evecs = torch.load(pca_result)['evecs']
    else:
        # reshape (not view): the permuted image tensor may be non-contiguous.
        data_matrix = (train_image - x_train_mean).reshape(train_image.shape[0], -1)
        cov = torch.matmul(data_matrix.T, data_matrix) / (data_matrix.shape[0] - 1)
        print(cov.shape)
        # The covariance is symmetric, so use eigh (torch.eig has been
        # removed from PyTorch). eigh returns ascending eigenvalues with
        # eigenvectors as columns; flip to descending so that rows 300+ of
        # `evecs` are the low-variance directions used by pro_small_scale.
        evals, evecs = torch.linalg.eigh(cov)
        evals = torch.flip(evals, dims=[0])
        evecs = torch.flip(evecs, dims=[1]).T.contiguous()
        torch.save({'evals': evals, 'evecs': evecs}, pca_result)

    evecs = evecs.to(device)

    if data_name in ("cifar10", "cifar100"):
        input_shape = (3, 32, 32)
    elif data_name == "mnist":
        input_shape = (1, 28, 28)
    else:
        raise ValueError("{} is not a dataset name".format(data_name))
    print(summary(model, input_shape))

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0002)
    adversary = LinfPGDAttack(model)
    criterion = nn.CrossEntropyLoss()
    # Per dataset: [epsilon, step size, number of PGD iterations].
    attack_setting = {'cifar10': [8 / 255, 2 / 255, 10],
                      'cifar100': [8 / 255, 2 / 255, 10],
                      'mnist': [0.3, 0.01, 40]}
    epsilon, alpha, n_iter = attack_setting[data_name]

    x_train_mean = x_train_mean.to(device)

    def noised_small_scale_feature(x):
        """Add uniform noise drawn in the PCA basis to ``x`` and clamp to [0, 1]."""
        images = x - x_train_mean
        img = images.reshape(images.shape[0], -1)
        # The projection is only used for its shape/dtype/device.
        coef_img = torch.matmul(evecs, img.T)
        noise = torch.zeros_like(coef_img).uniform_(-epsilon, epsilon)
        # Map the coefficient-space noise back to pixel space.
        noised_img = torch.matmul(evecs.T, noise).T
        noised_img = noised_img.reshape(x.shape)
        adjusted_img = torch.clamp(x + noised_img, 0, 1)
        return adjusted_img

    def _begin_step(inputs, targets):
        """Enter train mode, move the batch to the device and clear gradients."""
        model.train()
        inputs, targets = inputs.to(device), targets.to(device)
        # Gradients are cleared before any attack runs, as in the original flow.
        optimizer.zero_grad()
        return inputs, targets

    def _fit_batch(batch, targets):
        """One optimizer step on ``batch``; return (#correct, loss summed over the batch)."""
        outputs = model(batch)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        predicted = outputs.max(1)[1]
        train_correct = predicted.eq(targets).sum().item()
        return train_correct, loss.item() * len(targets)

    # training methods
    def train_std(inputs, targets):
        """Standard training on clean inputs."""
        inputs, targets = _begin_step(inputs, targets)
        return _fit_batch(inputs, targets)

    def remeve_small_scale(inputs, targets):
        """Training on inputs perturbed by noised_small_scale_feature."""
        inputs, targets = _begin_step(inputs, targets)
        rm_inputs = noised_small_scale_feature(inputs)
        return _fit_batch(rm_inputs, targets)

    def fgsm_at(inputs, targets):
        """Adversarial training on FGSM examples with step size epsilon."""
        inputs, targets = _begin_step(inputs, targets)
        adv = adversary.fgsm_perturb(inputs, targets, epsilon, alpha=epsilon).detach()
        return _fit_batch(adv, targets)

    def fat(inputs, targets):
        """FGSM training with uniform init and step size 1.25 * epsilon."""
        inputs, targets = _begin_step(inputs, targets)
        adv = adversary.fgsm_perturb(inputs, targets, epsilon, alpha=1.25*epsilon, data_init='uniform').detach()
        return _fit_batch(adv, targets)

    def pgd_at(inputs, targets):
        """Adversarial training on PGD examples."""
        inputs, targets = _begin_step(inputs, targets)
        adv = adversary.pgd_perturb(inputs, targets, epsilon, alpha, n_iter).detach()
        return _fit_batch(adv, targets)

    def ex_fgsm(inputs, targets):
        """FGSM training where the attack starts from the PCA-noised inputs."""
        inputs, targets = _begin_step(inputs, targets)

        rm_inputs = noised_small_scale_feature(inputs)
        if data_name == 'mnist':
            rm_inputs.requires_grad_()
            logits = model(rm_inputs)
            loss = F.cross_entropy(logits, targets)
            grad = torch.autograd.grad(loss, [rm_inputs])[0].detach()
            delta = (epsilon * torch.sign(grad)).detach()
            # NOTE(review): the step is added to the clean `inputs`, not to
            # `rm_inputs` -- behavior kept as-is; confirm this is intended.
            adv = torch.clamp(inputs + delta, 0, 1)
        else:
            adv = adversary.fgsm_perturb(rm_inputs, targets, epsilon, alpha=epsilon).detach()

        return _fit_batch(adv, targets)

    def fgsm_gradalign(inputs, targets):
        """FGSM training plus a gradient-alignment regularizer weighted by JR_lamda.

        The regularizer is ``1 - mean cosine similarity`` between the input
        gradient at the clean point and at a uniformly perturbed point.
        """
        inputs, targets = _begin_step(inputs, targets)

        adv = adversary.fgsm_perturb(inputs, targets, epsilon, alpha=epsilon).detach()
        adv_outputs = model(adv)
        loss_org = criterion(adv_outputs, targets)

        x_noise = inputs + torch.zeros_like(inputs).uniform_(-epsilon, epsilon)
        x_noise = torch.clamp(x_noise, 0, 1)

        x_noise.requires_grad = True
        loss_noise = F.cross_entropy(model(x_noise), targets)
        # create_graph=True keeps this gradient differentiable w.r.t. the
        # model parameters; without it the regularizer is a constant and
        # JR_lamda has no effect on the update.
        grad_noise = torch.autograd.grad(loss_noise, [x_noise], create_graph=True)[0]

        inputs.requires_grad = True
        loss_x = F.cross_entropy(model(inputs), targets)
        grad = torch.autograd.grad(loss_x, [inputs])[0]
        inputs.requires_grad = False

        grad = grad.reshape(inputs.shape[0], -1)
        grad_noise = grad_noise.reshape(inputs.shape[0], -1)

        loss_regu = 1 - torch.mean(F.cosine_similarity(grad, grad_noise, 1))
        loss = loss_org + JR_lamda * loss_regu

        loss.backward()
        optimizer.step()

        predicted = adv_outputs.max(1)[1]
        train_correct = predicted.eq(targets).sum().item()
        # Only the adversarial cross-entropy (not the regularizer) is reported.
        train_loss = loss_org.item()
        return train_correct, train_loss * len(targets)

    def pro_small_scale(a):
        """Sum over the batch of the squared-norm fraction of ``a`` lying in ``evecs[300:]``."""
        a = a.reshape(a.shape[0], -1)
        a_norm = torch.sum(a.mul(a), 1)
        tmp = torch.matmul(a, evecs[300:a.shape[1]].T)
        tmp_norm = torch.sum(tmp.mul(tmp), 1)
        a_p = tmp_norm.div(a_norm + 1e-10)  # epsilon guards against zero vectors
        return torch.sum(a_p).item()

    def test(inputs, targets):
        """Evaluate one test batch.

        Returns a tuple of batch sums (not means):
        ``(clean loss * batch size, #clean correct, #PGD correct,
        #FGSM correct, grad small-scale ratio, PGD small-scale ratio,
        FGSM small-scale ratio, cosine similarity of PGD vs FGSM perturbation)``.
        """
        model.eval()
        inputs, targets = inputs.to(device), targets.to(device)

        # pgd test
        adv_pgd = projected_gradient_descent(model, inputs, epsilon, alpha, n_iter, np.inf, 0, 1, targets, sanity_checks=False)
        pgd_preturb = adv_pgd - inputs
        with torch.no_grad():
            adv_outputs = model(adv_pgd)
            predicted = adv_outputs.max(1)[1]
            pgd_correct = predicted.eq(targets).sum().item()

        # fgsm test
        fgsm = fast_gradient_method(model, inputs, epsilon, np.inf, 0, 1, targets)
        fgsm_preturb = fgsm - inputs

        with torch.no_grad():
            fgsm_outputs = model(fgsm)
            predicted = fgsm_outputs.max(1)[1]
            fgsm_correct = predicted.eq(targets).sum().item()

        with torch.no_grad():
            # clean test
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss = loss.item()
            predicted = outputs.max(1)[1]
            test_correct = predicted.eq(targets).sum().item()

            # Input gradient of the true-class logit; the forward pass needs
            # grad mode re-enabled inside the surrounding no_grad block.
            inputs.requires_grad = True
            with torch.enable_grad():
                outputs = model(inputs)
                y_one_hot = F.one_hot(targets, outputs.shape[1])
                logit_true_label_loss = torch.sum(outputs * y_one_hot)
            grad = torch.autograd.grad(logit_true_label_loss, [inputs])[0]
            grad = grad.detach()
            inputs.requires_grad = False

            g_pro = pro_small_scale(grad)
            pgd_pro = pro_small_scale(pgd_preturb)
            fgsm_pro = pro_small_scale(fgsm_preturb)

            pgd_preturb = pgd_preturb.reshape(inputs.shape[0], -1)
            fgsm_preturb = fgsm_preturb.reshape(inputs.shape[0], -1)

            simi = F.cosine_similarity(pgd_preturb, fgsm_preturb, 1)
            simi = torch.sum(simi).item()
        return test_loss * len(targets), test_correct, pgd_correct, fgsm_correct, g_pro, pgd_pro, fgsm_pro, simi

    train_steps = {
        "std": train_std,
        "adv": pgd_at,
        "fat": fat,
        "fgsm": fgsm_at,
        "small_scale": remeve_small_scale,
        "ex_fgsm": ex_fgsm,
        "grad_align": fgsm_gradalign,
    }
    if train_mode not in train_steps:
        # Fail here rather than with a NameError inside the epoch loop.
        raise ValueError("{} is not a train mode".format(train_mode))
    train_step = train_steps[train_mode]

    def adjust_learning_rate(optimizer, epoch):
        """Step schedule: divide the base LR by 10 at epochs 40, 60 and 80."""
        lr = learning_rate
        if epoch >= 40:
            lr /= 10
        if epoch >= 60:
            lr /= 10
        if epoch >= 80:
            lr /= 10
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print(lr)

    val_acc_max = 0
    adv_acc_max = 0

    # Per-epoch metric histories; each key is also the .npy file stem.
    history = {
        'train_loss': [],
        'train_acc': [],
        'pgd_test_acc': [],
        'fgsm_test_acc': [],
        'test_loss': [],
        'test_acc': [],
        'grad_bias': [],
        'pgd_pro': [],
        'fgsm_pro': [],
        'cos': [],
    }

    best_adv = cpkt_save_path + '/' + 'best_adv' + '/'
    best_acc = cpkt_save_path + '/' + 'best_acc' + '/'
    epoch_result = cpkt_save_path + '/' + 'epoch_result' + '/'
    detail_path = cpkt_save_path + '/' + 'detail' + '/'
    for out_dir in (best_adv, best_acc, epoch_result, detail_path):
        os.makedirs(out_dir, exist_ok=True)

    for epoch in range(max_epoch):
        adjust_learning_rate(optimizer, epoch)
        # Reset the running sums at the start of each epoch.
        train_l = 0
        test_l = 0
        train_acc = 0
        test_acc = 0

        test_pgd_acc = 0
        test_fgsm_acc = 0
        pgd_p = 0
        fgsm_p = 0
        g_p = 0
        cos_sim = 0

        print('\nEopch:' + str(epoch + 1) + '/' + str(max_epoch))

        train_num = 0
        with tqdm(total=len(train_loader), ncols=100) as train_pbar:
            for images, labels in train_loader:
                train_correct, train_loss = train_step(images, labels)
                train_num += len(labels)
                train_acc += train_correct
                train_l += train_loss
                train_pbar.update(1)
                train_pbar.set_postfix({"Train Loss": train_l / train_num, "Train Acc": train_acc / train_num})

        test_num = 0
        with tqdm(total=len(test_loader), ncols=100) as test_pbar:
            for test_images, test_labels in test_loader:
                (test_loss, test_correct, pgd_correct, fgsm_correct,
                 g_pro, pgd_pro, fgsm_pro, cos_similarity) = test(test_images, test_labels)

                test_num += len(test_labels)
                test_acc += test_correct
                test_l += test_loss

                test_pgd_acc += pgd_correct
                test_fgsm_acc += fgsm_correct

                g_p += g_pro
                pgd_p += pgd_pro
                fgsm_p += fgsm_pro

                cos_sim += cos_similarity

                test_pbar.update(1)
                test_pbar.set_postfix({"Test Loss": test_l / test_num, "Test Acc": test_acc / test_num})

        print(test_num)
        print(
            "\nTest acc on FGM adversarial examples (%): {:.3f}".format(
                test_fgsm_acc / test_num),
            "Fgsm_pro: {:.3f}".format(fgsm_p / test_num), "Grad_pro: {:.3f}".format(g_p / test_num)
        )
        print(
            "Test acc on PGD adversarial examples (%): {:.3f}".format(
                test_pgd_acc / test_num), "Pgd_pro: {:.3f}".format(pgd_p / test_num),
            "Cos: {:.3f}".format(cos_sim / test_num)
        )

        history['train_loss'].append(train_l / train_num)
        history['train_acc'].append(train_acc / train_num)

        history['test_loss'].append(test_l / test_num)
        history['test_acc'].append(test_acc / test_num)

        history['pgd_test_acc'].append(test_pgd_acc / test_num)
        history['fgsm_test_acc'].append(test_fgsm_acc / test_num)

        history['grad_bias'].append(g_p / test_num)
        history['pgd_pro'].append(pgd_p / test_num)
        history['fgsm_pro'].append(fgsm_p / test_num)
        history['cos'].append(cos_sim / test_num)

        print("\n")
        if test_acc / test_num > val_acc_max:
            val_acc_max = test_acc / test_num
            torch.save(model.state_dict(), best_acc + 'model.ckpt')
            print("New best ckpt is saved to {}".format(best_acc))
        else:
            print("Clean acc didn't improve from {}%".format(100 * val_acc_max))

        if test_pgd_acc / test_num > adv_acc_max:
            adv_acc_max = test_pgd_acc / test_num
            torch.save(model.state_dict(), best_adv + 'model.ckpt')
            print("New best ckpt is saved to {}".format(best_adv))
        else:
            print("Pgd acc didn't improve from {}%".format(100 * adv_acc_max))

        # Unconditional per-epoch snapshot (not a "best" checkpoint).
        pt = epoch_result + str(epoch) + '/'
        os.makedirs(pt, exist_ok=True)
        torch.save(model.state_dict(), pt + 'model.ckpt')
        print("Epoch ckpt is saved to {}".format(pt))

        # Histories are rewritten every epoch so partial runs keep their metrics.
        for metric_name, values in history.items():
            np.save(os.path.join(detail_path, metric_name + '.npy'), values)
        np.save(os.path.join(detail_path, 'time.npy'), time.time() - start_time)


if __name__ == "__main__":
    # fire.Fire()

    # Experiments executed by this script, in order. Each entry is expanded
    # into train(**entry).
    # NOTE(review): make_device sets CUDA_VISIBLE_DEVICES on every call; once
    # CUDA has been initialised by the first run, the second run's device_no
    # may no longer take effect -- confirm before relying on it.
    active_runs = (
        dict(model_name='resnet18', data_name="cifar10", train_mode="ex_fgsm",
             device_no=3, batch_size=128, max_epoch=80),
        dict(model_name='lenet', data_name="mnist", train_mode="fgsm",
             device_no=2, batch_size=128, learning_rate=0.001),
    )
    for run_cfg in active_runs:
        train(**run_cfg)

    # ------------------------------------------------------------------
    # Reference presets (disabled). Copy one into `active_runs` to use it.
    # ------------------------------------------------------------------
    # resnet18 / cifar10
    # train(model_name='resnet18', data_name="cifar10", train_mode="std", device_no=0, batch_size=128)
    # train(model_name='resnet18', data_name="cifar10", train_mode="small_scale", device_no=0, batch_size=128, max_epoch=80)
    # train(model_name='resnet18', data_name="cifar10", train_mode="fgsm", device_no=0, batch_size=128, learning_rate=0.1, max_epoch=80)
    # train(model_name='resnet18', data_name="cifar10", train_mode="grad_align", device_no=0, batch_size=128, JR_lamda=0.2, max_epoch=80)
    # train(model_name='resnet18', data_name="cifar10", train_mode="adv", device_no=1, batch_size=128, max_epoch=80, learning_rate=0.01)

    # vgg / cifar10
    # train(model_name='vgg', data_name="cifar10", train_mode="std", device_no=1, batch_size=128, learning_rate=0.01)
    # train(model_name='vgg', data_name="cifar10", train_mode="adv", device_no=0, batch_size=128, learning_rate=0.01)
    # train(model_name='vgg', data_name="cifar10", train_mode="small_scale", device_no=1, batch_size=128, learning_rate=0.01)
    # train(model_name='vgg', data_name="cifar10", train_mode="grad_align", device_no=1, batch_size=128, JR_lamda=0.2, learning_rate=0.01)
    # train(model_name='vgg', data_name="cifar10", train_mode="fgsm", device_no=3, batch_size=128)
    # train(model_name='vgg', data_name="cifar10", train_mode="ex_fgsm", device_no=1, batch_size=128, learning_rate=0.01)

    # resnet34 / cifar100
    # train(model_name='resnet34', data_name="cifar100", train_mode="std", device_no=2, batch_size=128)
    # train(model_name='resnet34', data_name="cifar100", train_mode="small_scale", device_no=3, batch_size=128)
    # train(model_name='resnet34', data_name="cifar100", train_mode="ex_fgsm", device_no=0, batch_size=128)
    # train(model_name='resnet34', data_name="cifar100", train_mode="grad_align", device_no=3, batch_size=128, JR_lamda=0.2, max_epoch=80)
    # train(model_name='resnet34', data_name="cifar100", train_mode="adv", device_no=0, batch_size=128, learning_rate=5e-3, max_epoch=80)
    # train(model_name='resnet34', data_name="cifar100", train_mode="fgsm", device_no=3, batch_size=128, learning_rate=0.1, max_epoch=80)

    # wideresnet34_10 / cifar10
    # train(model_name='wideresnet34_10', data_name="cifar10", train_mode="std", device_no=0, batch_size=128, max_epoch=80)
    # train(model_name='wideresnet34_10', data_name="cifar10", train_mode="small_scale", device_no=1, batch_size=128, max_epoch=80)
    # train(model_name='wideresnet34_10', data_name="cifar10", train_mode="ex_fgsm", device_no=2, batch_size=128, max_epoch=80)
    # train(model_name='wideresnet34_10', data_name="cifar10", train_mode="adv", device_no=3, batch_size=128, max_epoch=80)
    # train(model_name='wideresnet34_10', data_name="cifar10", train_mode="fgsm", device_no=1, batch_size=128, max_epoch=80)
    # train(model_name='wideresnet34_10', data_name="cifar10", train_mode="grad_align", device_no=2, batch_size=128, JR_lamda=0.2, max_epoch=80)

    # preactresnet18 / cifar10
    # train(model_name='preactresnet18', data_name="cifar10", train_mode="std", device_no=2, batch_size=128)
    # train(model_name='preactresnet18', data_name="cifar10", train_mode="ex_fgsm", device_no=2, r=500, batch_size=128)
    # train(model_name='preactresnet18', data_name="cifar10", train_mode="grad_align", device_no=2, batch_size=128, JR_lamda=1)
    # train(model_name='preactresnet18', data_name="cifar10", train_mode="adv", device_no=2, batch_size=128)
    # train(model_name='preactresnet18', data_name="cifar10", train_mode="fgsm", device_no=2, batch_size=128, JR_lamda=1)
    # train(model_name='preactresnet18', data_name="cifar10", train_mode="small_scale", device_no=3, batch_size=128)

    # lenet / mnist
    # train(model_name='lenet', data_name="mnist", train_mode="small_scale", device_no=3, batch_size=128, learning_rate=0.01)
    # train(model_name='lenet', data_name="mnist", train_mode="std", device_no=3, batch_size=128, learning_rate=0.01)
    # train(model_name='lenet', data_name="mnist", train_mode="adv", device_no=0, batch_size=128, learning_rate=0.001)
    # train(model_name='lenet', data_name="mnist", train_mode="ex_fgsm", device_no=1, batch_size=128, learning_rate=0.01)
    # train(model_name='lenet', data_name="mnist", train_mode="fat", device_no=2, batch_size=128, learning_rate=0.01)
    # train(model_name='lenet', data_name="mnist", train_mode="grad_align", device_no=3, batch_size=128, JR_lamda=0.2, learning_rate=0.01)


